{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Solution for exercise\n",
    "---------\n",
    "\n",
    "```python\n",
    "def message_func(edges):\n",
    "    return {'m' : edges.src['h'] * edges.src['norm']}\n",
    "\n",
    "reduce_func = fn.sum('m', 'h')\n",
    "\n",
    "class GraphConv(gluon.Block):\n",
    "    def __init__(self, in_feats, out_feats):\n",
    "        super(GraphConv, self).__init__()\n",
    "        self.linear = nn.Dense(out_feats)\n",
    "    \n",
    "    def forward(self, g, inputs):\n",
    "        # g is the graph and the inputs is the input node features\n",
    "        # calculate norm = 1/sqrt(degree)\n",
    "        norm = 1.0 / nd.sqrt(g.in_degrees(g.nodes()).astype('float32')).reshape(-1, 1)\n",
    "        g.ndata['norm'] = norm\n",
    "        # perform linear transformation\n",
    "        h = self.linear(inputs)\n",
    "        # set the node features\n",
    "        g.ndata['h'] = h\n",
    "        # trigger message passing, using the defined message_func and reduce_func.\n",
    "        g.update_all(message_func, reduce_func)\n",
    "        # get the result node features\n",
    "        h = g.ndata.pop('h') * norm\n",
    "        return h\n",
    "\n",
    "inputs = nd.eye(34)  # featureless inputs\n",
    "labeled_nodes = nd.array([0, 33])  # only the instructor and the president nodes are labeled\n",
    "labels = nd.array([0, 1])  # their labels are different\n",
    "\n",
    "net = GCN(34, 5, 2)\n",
    "net.initialize()\n",
    "trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.05})\n",
    "loss_fn = gluon.loss.SoftmaxCELoss()\n",
    "\n",
    "all_logits = []\n",
    "for epoch in range(30):\n",
    "    with autograd.record():\n",
    "        logits = net(G, inputs)\n",
    "        # we only compute loss for node 0 and node 33\n",
    "        loss = loss_fn(logits[labeled_nodes], labels).sum()\n",
    "    all_logits.append(logits.detach())\n",
    "    \n",
    "    loss.backward()\n",
    "    trainer.step(batch_size=1)\n",
    "    \n",
    "    print('Epoch %d | Loss: %.4f' % (epoch, loss.asscalar()))\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
